#------------------------------------------------------------------------------
# plot trends over time
# Figure 1. Pain diagnosis
# Figure S1. Enrollment 
# Figure 2. Trends among pain patients
#==============================================================================

load_library = c('bit64','data.table','fst','future.apply','stringr','logger','vroom')
invisible(lapply(load_library, function(x) library(x, character.only=TRUE, quietly= TRUE)))

library(ggplot2)
library(ggpubr)

# Root folder of the replication package. The default is the original local
# Dropbox location; set the OP_DATAVERSE_DIR environment variable to run the
# script on another machine without editing the code.
default_bucket <- '/Users/bk/Dropbox/project/OP/- posted/2021-06-08_covid_substitution_JAMA/dataverse'
bucket <- Sys.getenv('OP_DATAVERSE_DIR', unset = default_bucket)
if (!nzchar(bucket)) {
  # An empty variable is treated the same as an unset one.
  bucket <- default_bucket
}
data_path <- file.path(bucket, 'data')

# Input table in fst format, read further down with read_fst().
infile <- file.path(data_path, 'processed_data', 'ses_division_weekly_covid_0101_0930.fst')

# ----------
# Load the input table and derive the calendar fields used by the figures.
logger::log_info('now loading data ... ', infile)

dt <- read_fst(infile, as.data.table = TRUE)

# `date` is split on the literal 'W' into a year part and a week-number part.
# tstrsplit() already removes the separator, so no further clean-up of `week`
# is needed before converting it to a number.
dt[, c('year', 'week') := tstrsplit(date, 'W', fixed = TRUE)]
dt[, week := as.numeric(week)]
# Rebuild `date` as a Date: %U is the week of the year (weeks start on Sunday)
# and %u = 1 selects the Monday of that week.
dt[, date := as.Date(paste(year, week, 1, sep = "-"), "%Y-%U-%u")]
dt[, month := month(date)]

# Convert every opioid measure of every patient group to double in a single
# by-reference update. Assigning with `dt[[col]] = ...` would copy the whole
# table once per column.
# NOTE(review): presumably these columns are stored as integer64 (bit64 is
# loaded above) -- confirm against the input file.
pre_condition_list <- c('all', 'backpain', 'pain', 'ord')
opioid_measures <- c(
  'sum_sum_opioid_quantity', 'sum_sum_opioid_days', 'sum_sum_opioid_mme',
  'mean_sum_opioid_quantity', 'mean_sum_opioid_days', 'mean_sum_opioid_mme'
)
opioid_cols <- as.vector(outer(pre_condition_list, opioid_measures, paste, sep = '_'))
dt[, (opioid_cols) := lapply(.SD, as.numeric), .SDcols = opioid_cols]

# Drop the first week: only rows with a week number above 1 are kept, so
# weeks 0 and 1 of each year are both removed.
dt <- dt[week > 1, ]

#--------------
# % pain patients
#==============

# Shared implementation of the confidence bounds below.
#
# Returns weighted mean + direction * 1.96 * (weighted SD / sqrt(N)), where N
# is the sum of the weights. Observations with a missing `x` are dropped
# before anything is computed, so the mean, the variance and N all refer to
# the same set of observations. Summing every weight, including those of the
# dropped rows, would overstate N and make the interval too narrow.
#
# x          numeric vector of values
# w          numeric vector of weights, same length as x
# direction  -1 for the lower bound, +1 for the upper bound
.wtd_ci_bound <- function(x, w, direction) {
  keep <- !is.na(x)
  x <- x[keep]
  w <- w[keep]
  wmean <- weighted.mean(x, w, na.rm = TRUE)
  wsd <- sqrt(Hmisc::wtd.var(x, weights = w, na.rm = TRUE))
  N <- sum(w)
  wmean + direction * (wsd / sqrt(N) * 1.96)
}

# Lower bound of the 95% interval around the weighted mean of `x`
# (weights `w`).
lci <- function(x, w) {
  .wtd_ci_bound(x, w, direction = -1)
}

# Upper bound of the 95% interval around the weighted mean of `x`
# (weights `w`).
uci <- function(x, w) {
  .wtd_ci_bound(x, w, direction = 1)
}
 
mean_target_var <- 'all_mean_targetpain'
lci_target_var <- paste0('lci_', mean_target_var)
uci_target_var <- paste0('uci_', mean_target_var)

# One grouped pass per date: the weighted mean of the target column and its
# lower/upper bounds, all weighted by `all_sum_pain`. `keyby` returns the rows
# sorted and keyed by date.
mean_dt <- dt[
  ,
  .(
    estimate = weighted.mean(.SD[[1L]], w = all_sum_pain, na.rm = TRUE),
    lower = lci(.SD[[1L]], w = all_sum_pain),
    upper = uci(.SD[[1L]], w = all_sum_pain)
  ),
  keyby = 'date',
  .SDcols = mean_target_var
]
setnames(
  mean_dt,
  old = c('estimate', 'lower', 'upper'),
  new = c(mean_target_var, lci_target_var, uci_target_var)
)

# Record the calendar year of each point as a two-level factor, then move the
# 2019 dates into 2020 so both years share one x axis.
mean_dt[, date := as.Date(date)]
mean_dt[, year := factor(year(date), levels = c(2019, 2020))]
mean_dt[, date := as.Date(gsub('2019', '2020', date))]

label_y <- '% Patients with pain diagnosis'

# Builds the two layers that mark one calendar event on the figure: a small
# text label at height `text_y`, and a straight arrow that starts at
# `arrow_from_y` and ends just above the label.
event_annotation <- function(event_date, label, text_y, arrow_from_y) {
  event_date <- as.Date(event_date)
  list(
    annotate("text", color = 'gray10',
      x = event_date, y = text_y, size = 2, label = label),
    annotate(geom = "curve", color = 'gray10',
      x = event_date, y = arrow_from_y, linetype = 'solid',
      xend = event_date, yend = text_y + 0.05,
      curvature = 0, arrow = arrow(length = unit(1, 'mm')))
  )
}

# Figure 1: share of patients with a pain diagnosis, one line per year.
# The y values are proportions scaled to percentages. ymin/ymax carry the
# interval bounds so the ribbon layer can be switched back on without touching
# the aesthetics.
p <- ggplot(
    mean_dt[date <= as.Date('2020-09-30'), ],
    aes(
      x = date,
      y = .data[[mean_target_var]] * 100,
      group = year, color = year, fill = year,
      ymin = .data[[paste0('lci_', mean_target_var)]] * 100,
      ymax = .data[[paste0('uci_', mean_target_var)]] * 100
    )
  ) +
  geom_line() +
  #geom_ribbon(alpha=0.3)+
  geom_point() +
  scale_x_date(date_breaks = '1 month', date_labels = '%b') +
  scale_color_manual(name = 'year', values = c('2019' = 'black', '2020' = 'red')) +
  scale_fill_manual(name = 'year', values = c('2019' = 'gray', '2020' = 'pink')) +
  theme_bw() +
  ylim(0, 4.5) +
  theme(legend.position = 'top') +
  labs(y = label_y, x = NULL) +
  # Event markers use the 2020 calendar dates. Memorial Day 2020 fell on
  # May 25 and Labor Day 2020 on September 7.
  # NOTE(review): the arrow start heights were tuned by eye for the previous
  # x positions -- check the rendered figure after this date correction.
  event_annotation('2020-03-13', 'National Emergency Declaration',
    text_y = 0.5, arrow_from_y = 3.8) +
  event_annotation('2020-05-25', 'Memorial Day',
    text_y = 0.2, arrow_from_y = 2.8) +
  event_annotation('2020-07-04', 'Independence Day',
    text_y = 0.5, arrow_from_y = 3.2) +
  event_annotation('2020-09-07', 'Labor Day',
    text_y = 0.2, arrow_from_y = 3.6)

ggsave(file.path(data_path, 'figure_1_pain_trends.eps'), p, width = 5, height = 4)

